{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Quantization-Aware Training with BatchNorm Re-estimation\n",
    "\n",
    "This notebook shows a working code example of how to use AIMET to perform quantization-aware training (QAT) with batchnorm re-estimation.\n",
    "Batchnorm re-estimation is a technique for countering potential instability of batchnorm statistics (i.e., running mean and variance) during QAT. More specifically, batchnorm re-estimation recalculates the batchnorm statistics based on the model after QAT. By doing so, we aim to make our model learn batchnorm statistics from the stable outputs after QAT, rather than from the likely noisy outputs during QAT.\n",
    "\n",
    "#### Overall flow\n",
    "This notebook covers the following steps:\n",
    "1. Create a quantization simulation model with fake quantization ops inserted.\n",
    "2. Finetune and evaluate the quantization simulation model.\n",
    "3. Re-estimate batchnorm statistics and compare the eval score before and after re-estimation.\n",
    "4. Fold the re-estimated batchnorm layers and export the quantization simulation model.\n",
    "\n",
    "#### What this notebook is not\n",
    "In this notebook, we will focus on how to apply batchnorm re-estimation after QAT, rather than covering all the details of QAT itself. For more information about QAT, please refer to the [QAT notebook](https://github.com/quic/aimet/blob/develop/Examples/torch/quantization/qat.ipynb) or the [QAT range learning notebook](https://github.com/quic/aimet/blob/develop/Examples/torch/quantization/qat_range_learning.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Dataset\n",
    "\n",
    "This notebook relies on the ImageNet dataset for the task of image classification. If you already have a version of the dataset readily available, please use that. Otherwise, please download the dataset from an appropriate location (e.g. https://image-net.org/challenges/LSVRC/2012/index.php#).\n",
    "\n",
    "**Note1**: The ImageNet dataset typically has the following characteristics, and the dataloader provided in this example notebook relies on them:\n",
    "- Subfolders 'train' for the training samples and 'val' for the validation samples. Please see the [pytorch dataset description](https://pytorch.org/vision/0.8/_modules/torchvision/datasets/imagenet.html) for more details.\n",
    "- A subdirectory per class, and one file per image sample\n",
    "\n",
    "**Note2**: To speed up the execution of this notebook, you may use a reduced subset of the ImageNet dataset. E.g., the full ILSVRC2012 dataset has 1000 classes, up to 1300 training samples per class and 50 validation samples per class. For the purpose of running this notebook, you could reduce the dataset to, say, 2 samples per class. This exercise is optional and left to the reader.\n",
    "\n",
    "Edit the cell below and specify the directory where the downloaded ImageNet dataset is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_DIR = '/path/to/dataset/'         # Please replace this with a real directory"
   ]
  },
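  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note2 above leaves building a reduced dataset to the reader. One possible approach is sketched below: copy the first few files of every class folder under `train` and `val` into a new directory with the same layout. The helper `make_imagenet_subset` and the destination path are hypothetical (not part of AIMET or this repository), and the sketch assumes the `<split>/<class>/<image>` layout described in Note1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "\n",
    "def make_imagenet_subset(src_dir: str, dst_dir: str, samples_per_class: int = 2) -> int:\n",
    "    \"\"\"\n",
    "    Hypothetical helper (not part of AIMET): copies up to `samples_per_class` files per class\n",
    "    from src_dir/{train,val}/<class>/ to dst_dir, keeping the same folder layout.\n",
    "    Returns the total number of copied files.\n",
    "    \"\"\"\n",
    "    copied = 0\n",
    "    for split in ('train', 'val'):\n",
    "        split_dir = os.path.join(src_dir, split)\n",
    "        if not os.path.isdir(split_dir):\n",
    "            continue  # skip a split that is not present\n",
    "        for class_name in sorted(os.listdir(split_dir)):\n",
    "            class_dir = os.path.join(split_dir, class_name)\n",
    "            if not os.path.isdir(class_dir):\n",
    "                continue\n",
    "            out_dir = os.path.join(dst_dir, split, class_name)\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "            # Sort the file names so that the selected samples are deterministic\n",
    "            files = [f for f in sorted(os.listdir(class_dir))\n",
    "                     if os.path.isfile(os.path.join(class_dir, f))]\n",
    "            for file_name in files[:samples_per_class]:\n",
    "                shutil.copy2(os.path.join(class_dir, file_name), out_dir)\n",
    "                copied += 1\n",
    "    return copied\n",
    "\n",
    "\n",
    "# Example usage (commented out so that no data is copied by accident):\n",
    "# make_imagenet_subset(DATASET_DIR, '/path/to/reduced_dataset/', samples_per_class=2)\n",
    "# DATASET_DIR = '/path/to/reduced_dataset/'"
   ]
  },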
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## 1. Example evaluation and training pipeline\n",
    "\n",
    "The following is an example training and validation loop for this image classification task.\n",
    "\n",
    "- **Does AIMET have any limitations on how the training and validation pipeline is written?** Not really. We will see later that AIMET modifies the user's model to create a QuantizationSim model, which is still a PyTorch model. This QuantizationSim model can be used in place of the original model when doing inference or training.\n",
    "- **Does AIMET put any limitation on the interface of the evaluate() or train() methods?** Not really. You should be able to use your existing evaluate and train routines as-is.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from Examples.common import image_net_config\n",
    "from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator\n",
    "from Examples.torch.utils.image_net_trainer import ImageNetTrainer\n",
    "from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader\n",
    "\n",
    "class ImageNetDataPipeline:\n",
    "\n",
    "    @staticmethod\n",
    "    def get_val_dataloader() -> torch.utils.data.DataLoader:\n",
    "        \"\"\"\n",
    "        Instantiates a validation dataloader for ImageNet dataset and returns it\n",
    "        \"\"\"\n",
    "        data_loader = ImageNetDataLoader(DATASET_DIR,\n",
    "                                         image_size=image_net_config.dataset['image_size'],\n",
    "                                         batch_size=image_net_config.evaluation['batch_size'],\n",
    "                                         is_training=False,\n",
    "                                         num_workers=image_net_config.evaluation['num_workers']).data_loader\n",
    "        return data_loader\n",
    "\n",
    "    @staticmethod\n",
    "    def evaluate(model: torch.nn.Module, use_cuda: bool) -> float:\n",
    "        \"\"\"\n",
    "        Given a torch model, evaluates its Top-1 accuracy on the dataset\n",
    "        :param model: the model to evaluate\n",
    "        :param use_cuda: whether or not the GPU should be used.\n",
    "        \"\"\"\n",
    "        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],\n",
    "                                      batch_size=image_net_config.evaluation['batch_size'],\n",
    "                                      num_workers=image_net_config.evaluation['num_workers'])\n",
    "\n",
    "        return evaluator.evaluate(model, iterations=None, use_cuda=use_cuda)\n",
    "\n",
    "    @staticmethod\n",
    "    def finetune(model: torch.nn.Module, epochs, learning_rate, learning_rate_schedule, use_cuda):\n",
    "        \"\"\"\n",
    "        Given a torch model, finetunes the model to improve its accuracy\n",
    "        :param model: the model to finetune\n",
    "        :param epochs: The number of epochs used during the finetuning step.\n",
    "        :param learning_rate: The learning rate used during the finetuning step.\n",
    "        :param learning_rate_schedule: The learning rate schedule used during the finetuning step.\n",
    "        :param use_cuda: whether or not the GPU should be used.\n",
    "        \"\"\"\n",
    "        trainer = ImageNetTrainer(DATASET_DIR, image_size=image_net_config.dataset['image_size'],\n",
    "                                  batch_size=image_net_config.train['batch_size'],\n",
    "                                  num_workers=image_net_config.train['num_workers'])\n",
    "\n",
    "        trainer.train(model, max_epochs=epochs, learning_rate=learning_rate,\n",
    "                      learning_rate_schedule=learning_rate_schedule, use_cuda=use_cuda)"
   ]
  },
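  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the QuantizationSim model is still a `torch.nn.Module`, any evaluation loop that accepts a model works with it. As an illustration, here is a minimal Top-1 evaluation loop in plain PyTorch; `evaluate_top1` is hypothetical and not part of AIMET or of the `Examples` utilities used above. Note that `model.eval()` matters here: it makes BatchNorm layers use their running statistics, which are exactly what batchnorm re-estimation later updates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def evaluate_top1(model: torch.nn.Module, data_loader, device: torch.device) -> float:\n",
    "    \"\"\"\n",
    "    Hypothetical minimal evaluation loop: returns Top-1 accuracy in percent.\n",
    "    Works with any torch.nn.Module, including a QuantizationSimModel's `.model`.\n",
    "    \"\"\"\n",
    "    model.eval()  # BatchNorm uses running statistics; dropout is disabled\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in data_loader:\n",
    "            logits = model(images.to(device))\n",
    "            predictions = logits.argmax(dim=1)\n",
    "            correct += (predictions == labels.to(device)).sum().item()\n",
    "            total += labels.size(0)\n",
    "    return 100.0 * correct / max(total, 1)\n",
    "\n",
    "\n",
    "# Example usage once `sim` and `device` are defined further below:\n",
    "# print(evaluate_top1(sim.model, ImageNetDataPipeline.get_val_dataloader(), device))"
   ]
  },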
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## 2. Load FP32 model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "For this example notebook, we are going to load a pretrained resnet18 model from torchvision. You can load any other pretrained PyTorch model in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "from aimet_torch.model_preparer import prepare_model\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "if use_cuda:\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model = resnet18(pretrained=True).to(device)\n",
    "model = prepare_model(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "## 3. Create a Quantization Simulation Model and Perform QAT\n",
    "\n",
    "### Create Quantization Sim Model\n",
    "\n",
    "Now we use AIMET to create a QuantizationSimModel. AIMET inserts fake quantization ops into the model graph and configures them.\n",
    "A few of the parameters are explained here:\n",
    "- **quant_scheme**: We set this to `QuantScheme.training_range_learning_with_tf_init`\n",
    "    - With this scheme, the quantization ranges are initialized using the `tf` (min/max) scheme and are then learned during QAT. Other options include `QuantScheme.post_training_tf`, `QuantScheme.post_training_tf_enhanced` and `QuantScheme.training_range_learning_with_tf_enhanced_init`\n",
    "- **default_output_bw**: Setting this to 8 means that we are asking AIMET to perform all activation quantizations in the model using integer 8-bit precision\n",
    "- **default_param_bw**: Setting this to 8 means that we are asking AIMET to perform all parameter quantizations in the model using integer 8-bit precision\n",
    "\n",
    "There are other parameters that are set to default values in this example. Please check the AIMET API documentation of QuantizationSimModel for reference documentation on all the parameters.\n",
    "\n",
    "**NOTE**: Unlike in the other QAT example notebooks, we did not fold batchnorm layers before QAT. This is because we aim to finetune our model with the batchnorm layers present and re-estimate the batchnorm statistics for better accuracy. The batchnorm layers will be folded after re-estimation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_common.defs import QuantScheme\n",
    "from aimet_torch.v1.quantsim import QuantizationSimModel\n",
    "\n",
    "dummy_input = torch.rand(1, 3, 224, 224, device=device)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)\n",
    "\n",
    "sim = QuantizationSimModel(model=model,\n",
    "                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,\n",
    "                           dummy_input=dummy_input,\n",
    "                           default_output_bw=8,\n",
    "                           default_param_bw=8)\n",
    "\n",
    "def pass_calibration_data(sim_model, use_cuda):\n",
    "    data_loader = ImageNetDataPipeline.get_val_dataloader()\n",
    "    batch_size = data_loader.batch_size\n",
    "\n",
    "    if use_cuda:\n",
    "        device = torch.device('cuda')\n",
    "    else:\n",
    "        device = torch.device('cpu')\n",
    "\n",
    "    samples = 1000\n",
    "    batch_cntr = 0\n",
    "\n",
    "    for input_data, target_data in data_loader:\n",
    "        inputs_batch = input_data.to(device)\n",
    "        sim_model(inputs_batch)\n",
    "\n",
    "        batch_cntr += 1\n",
    "        if (batch_cntr * batch_size) > samples:\n",
    "            break\n",
    "                \n",
    "sim.compute_encodings(forward_pass_callback=pass_calibration_data,\n",
    "                      forward_pass_callback_args=use_cuda)"
   ]
  },
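  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "QuantizationSimModel inserts fake quantization (quantize-dequantize) ops, and `compute_encodings` sets their ranges from the data fed by `pass_calibration_data`. As a conceptual sketch of what a single 8-bit op computes, the code below derives an asymmetric scale and integer offset from the observed min/max (the idea behind the `tf` scheme, which `training_range_learning_with_tf_init` uses for initialization) and then rounds and clamps to the integer grid. The helpers `compute_minmax_encoding` and `fake_quantize` are hypothetical; AIMET's own implementation differs in details, and with range learning the ranges are further updated during QAT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def compute_minmax_encoding(x: torch.Tensor, bitwidth: int = 8):\n",
    "    # Hypothetical min/max calibration. The range is widened to include 0\n",
    "    # so that zero is exactly representable after quantization.\n",
    "    x_min = min(x.min().item(), 0.0)\n",
    "    x_max = max(x.max().item(), 0.0)\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    scale = max(x_max - x_min, 1e-8) / num_steps\n",
    "    offset = round(x_min / scale)  # integer offset, <= 0\n",
    "    return scale, offset\n",
    "\n",
    "\n",
    "def fake_quantize(x: torch.Tensor, scale: float, offset: int, bitwidth: int = 8) -> torch.Tensor:\n",
    "    # Quantize to integers in [0, 2^bitwidth - 1], then dequantize back to float\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    q = torch.clamp(torch.round(x / scale) - offset, 0, num_steps)\n",
    "    return (q + offset) * scale\n",
    "\n",
    "\n",
    "# Example: quantize-dequantize a random activation tensor\n",
    "x = torch.randn(4, 16)\n",
    "scale, offset = compute_minmax_encoding(x)\n",
    "x_fq = fake_quantize(x, scale, offset)\n",
    "print('scale:', scale, 'offset:', offset, 'max abs error:', (x_fq - x).abs().max().item())"
   ]
  },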
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform QAT\n",
    "\n",
    "To perform quantization aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.\n",
    "\n",
    "For the purpose of this example notebook, we are going to train only for 1 epoch. But feel free to change these parameters as you see fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ImageNetDataPipeline.finetune(sim.model, epochs=1, learning_rate=5e-7, learning_rate_schedule=[5, 10], use_cuda=use_cuda)"
   ]
  },
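  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `learning_rate_schedule=[5, 10]` argument above appears to list the epochs at which the learning rate is dropped; this is an assumption, so check `ImageNetTrainer` for how it is actually interpreted. In plain PyTorch, dropping the learning rate by a factor of 10 at given epochs can be expressed with `torch.optim.lr_scheduler.MultiStepLR`. The hypothetical helper below only records the learning rate used in each epoch; the actual QAT step is left as a comment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def lr_per_epoch(base_lr: float, milestones, num_epochs: int, gamma: float = 0.1):\n",
    "    # A dummy parameter stands in for sim.model.parameters() in this sketch\n",
    "    params = [torch.nn.Parameter(torch.zeros(1))]\n",
    "    optimizer = torch.optim.SGD(params, lr=base_lr, momentum=0.9)\n",
    "    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)\n",
    "    lrs = []\n",
    "    for _ in range(num_epochs):\n",
    "        lrs.append(optimizer.param_groups[0]['lr'])\n",
    "        # ... one epoch of QAT (forward pass, loss.backward(), optimizer.step()) goes here ...\n",
    "        optimizer.step()\n",
    "        scheduler.step()  # advance the schedule once per epoch\n",
    "    return lrs\n",
    "\n",
    "\n",
    "# Learning rate per epoch for a 12-epoch run with drops at epochs 5 and 10\n",
    "print(lr_per_epoch(5e-7, milestones=[5, 10], num_epochs=12))"
   ]
  },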
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After QAT, we run quantization simulation inference on the validation dataset to observe any improvement in accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "finetuned_accuracy = ImageNetDataPipeline.evaluate(sim.model, use_cuda)\n",
    "print(finetuned_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 4. Perform BatchNorm Re-estimation\n",
    "\n",
    "### Re-estimate BatchNorm Statistics\n",
    "AIMET provides a helper function, `reestimate_bn_stats`, for re-estimating batchnorm statistics.\n",
    "Here is the full list of parameters for this function:\n",
    "* **model**: Model whose BatchNorm statistics are re-estimated.\n",
    "* **dataloader**: Training dataloader.\n",
    "* **num_batches** (optional): The number of batches to be used for re-estimation. (Default: 100)\n",
    "* **forward_fn** (optional): Optional adapter function that performs a forward pass given a model and an input batch yielded from the data loader. If not specified, the inputs yielded from the dataloader are expected to be passed directly to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_torch.bn_reestimation import reestimate_bn_stats\n",
    "\n",
    "train_loader = ImageNetDataLoader(images_dir=DATASET_DIR,\n",
    "                                  image_size=image_net_config.dataset['image_size'],\n",
    "                                  batch_size=image_net_config.train['batch_size'],\n",
    "                                  is_training=True,\n",
    "                                  num_workers=image_net_config.train['num_workers']).data_loader\n",
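    "\n",
    "def forward_fn(model, inputs):\n",
    "    # The train loader yields (images, labels); only the images are needed,\n",
    "    # and they must be on the same device as the model.\n",
    "    input_data, target_data = inputs\n",
    "    model(input_data.to(device))\n",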
    "\n",
    "reestimate_bn_stats(sim.model, train_loader, forward_fn=forward_fn)\n",
    "\n",
    "finetuned_accuracy_bn_reestimated = ImageNetDataPipeline.evaluate(sim.model, use_cuda)\n",
    "print(finetuned_accuracy_bn_reestimated)"
   ]
  },
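  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea behind `reestimate_bn_stats` concrete, below is a simplified plain-PyTorch sketch: it resets each BatchNorm layer's running statistics, switches the layers to a cumulative average (`momentum=None`), runs forward passes over up to `num_batches` training batches in train mode without gradients, and finally restores the original momentum and train/eval mode. The function `reestimate_bn_stats_sketch` is hypothetical and is not AIMET's implementation, which may differ in details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)\n",
    "\n",
    "\n",
    "def reestimate_bn_stats_sketch(model, data_loader, num_batches=100, forward_fn=None):\n",
    "    # Hypothetical, simplified re-estimation (not AIMET's implementation)\n",
    "    bn_layers = [m for m in model.modules() if isinstance(m, BN_TYPES)]\n",
    "    saved_momentum = {bn: bn.momentum for bn in bn_layers}\n",
    "    was_training = model.training\n",
    "\n",
    "    for bn in bn_layers:\n",
    "        bn.reset_running_stats()  # running_mean=0, running_var=1, num_batches_tracked=0\n",
    "        bn.momentum = None        # None: cumulative (equally weighted) average over batches\n",
    "\n",
    "    model.train()          # BatchNorm updates its running statistics only in train mode\n",
    "    with torch.no_grad():  # only statistics change; no gradients, no weight updates\n",
    "        for i, batch in enumerate(data_loader):\n",
    "            if i >= num_batches:\n",
    "                break\n",
    "            if forward_fn is not None:\n",
    "                forward_fn(model, batch)\n",
    "            else:\n",
    "                model(batch)\n",
    "\n",
    "    # Restore the original momentum values and train/eval mode\n",
    "    for bn in bn_layers:\n",
    "        bn.momentum = saved_momentum[bn]\n",
    "    model.train(was_training)\n",
    "\n",
    "\n",
    "# Example (commented out: the cell above already re-estimated sim.model with AIMET):\n",
    "# reestimate_bn_stats_sketch(sim.model, train_loader, forward_fn=forward_fn)"
   ]
  },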
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fold BatchNorm Layers\n",
    "\n",
    "So far, we have improved our quantization simulation model through QAT and batchnorm re-estimation. The next step is to take this model to target. But first, we fold the batchnorm layers so that the model runs more efficiently on target devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale\n",
    "\n",
    "fold_all_batch_norms_to_scale(sim)"
   ]
  },
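  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Folding absorbs each BatchNorm's affine transform into the preceding layer. For a convolution with weight W and bias b followed by a BatchNorm with scale γ, shift β, running mean μ and running variance σ², the folded parameters are W' = W · γ / √(σ² + ε) (per output channel) and b' = (b - μ) · γ / √(σ² + ε) + β. The plain-PyTorch sketch below applies this arithmetic to a Conv2d/BatchNorm2d pair and checks that the outputs match in eval mode. The helper `fold_bn_into_conv` is hypothetical; it only illustrates the standard folding arithmetic, whereas `fold_all_batch_norms_to_scale` operates on the quantization simulation model (as its name suggests, folding into the quantization scale), which this sketch does not model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:\n",
    "    # Hypothetical helper: returns a new Conv2d equivalent to bn(conv(x)) in eval mode\n",
    "    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,\n",
    "                            stride=conv.stride, padding=conv.padding, dilation=conv.dilation,\n",
    "                            groups=conv.groups, bias=True, padding_mode=conv.padding_mode)\n",
    "    with torch.no_grad():\n",
    "        std = torch.sqrt(bn.running_var + bn.eps)\n",
    "        factor = bn.weight / std  # per-output-channel multiplier gamma / sqrt(var + eps)\n",
    "        fused.weight.copy_(conv.weight * factor.reshape(-1, 1, 1, 1))\n",
    "        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)\n",
    "        fused.bias.copy_((conv_bias - bn.running_mean) * factor + bn.bias)\n",
    "    return fused\n",
    "\n",
    "\n",
    "# Example: the folded conv reproduces conv followed by batchnorm (eval mode)\n",
    "conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False).eval()\n",
    "bn = torch.nn.BatchNorm2d(4).eval()\n",
    "x = torch.randn(2, 3, 8, 8)\n",
    "print(torch.allclose(fold_bn_into_conv(conv, bn)(x), bn(conv(x)), atol=1e-5))"
   ]
  },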
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 5. Export Model\n",
    "As the final step, we will export the model to run it on actual target devices. AIMET QuantizationSimModel provides an export API for this purpose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "os.makedirs('./output/', exist_ok=True)\n",
    "dummy_input = dummy_input.cpu()\n",
    "sim.export(path='./output/', filename_prefix='resnet18_after_qat', dummy_input=dummy_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "We hope this notebook was useful for understanding how to use the batchnorm re-estimation feature of AIMET.\n",
    "\n",
    "A few additional resources:\n",
    "- Refer to the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) for more details on the APIs and optional parameters.\n",
    "- Refer to the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/torch/quantization) to learn how to use AIMET post-training quantization techniques and QAT methods."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
